Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make influence_fn a higher-order Functional #492

Merged
merged 6 commits into from
Jan 11, 2024

Conversation

eb8680
Copy link
Contributor

@eb8680 eb8680 commented Jan 10, 2024

Addresses #479, #480

This PR changes the type signature of chirho.robust.ops.influence_fn to make it more compositional, following the suggestions in #479 and #480. In particular, influence_fn is now a higher-order functional that maps a target functional to another functional:

def influence_fn(functional: Functional[P, S], *points: Point[T]) -> Functional[P, S]: ...

It also generalizes the Functional interface (and therefore those of influence_fn and linearize) to arbitrary arities (but still only implements the unary case):

class Functional(Protocol[P, S]):
    def __call__(self, __model: Callable[P, Any], *models: Callable[P, Any]) -> Callable[P, S]: ...

Bundling these interface changes together in one PR will hopefully ensure that there will only be one set of merge conflicts associated with #479 and #480 to resolve in other open PRs targeting staging-robust.

@eb8680 eb8680 added enhancement New feature or request refactor status:WIP Work-in-progress not yet ready for review module:robust labels Jan 10, 2024
@eb8680 eb8680 added status:awaiting review Awaiting response from reviewer and removed status:WIP Work-in-progress not yet ready for review labels Jan 11, 2024
Copy link
Contributor

@agrawalraj agrawalraj left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm!

@agrawalraj agrawalraj merged commit c4346c8 into staging-robust Jan 11, 2024
6 checks passed
@agrawalraj agrawalraj deleted the eb-robust-make-influence-functional branch January 11, 2024 19:19
eb8680 added a commit that referenced this pull request Jan 12, 2024
* added robust folder

* uncommited scratch work for log prob

* untested variational log prob

* uncomitted changes

* uncomitted changes

* pair coding w/ eli

* added tests w/ Eli

* eif

* linting

* moving test autograd to internals and deleted old utils file

* sketch influence implementation

* fix more args

* ops file

* file

* format

* lint

* clean up influence and tests

* make tests more generic

* guess max plate nesting

* linearize

* rename file

* tensor flatten

* predictive eif

* jvp type

* reorganize files

* shrink test case

* move guess_max_plate_nesting

* move cg solver to linearze

* type alias

* test_ops

* basic cg tests

* remove failing test case

* format

* move paramdict up

* remove obsolete test files

* add empty handlers

* add chirho.robust to docs

* fix memory leak in tests

* make typing compatible with python 3.8

* typing_extensions

* add branch to ci

* predictive

* remove imprecise annotation

* Added more tests for `linearize` and `make_empirical_fisher_vp` (#405)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* removed missing import

* fixed failing test with seeding

* addressing Eli's comments

* Add upper bound on number of CG steps (#404)

* upper bound on cg_iters

* address comment

* fixed test for non-symmetric matrix (#437)

* Make `NMCLogPredictiveLikelihood` seeded (#408)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* seeded NMC implementation

* linting

* removed missing import

* changed to eli's seedmessenger suggestion

* added failing edge case

* explicitly add max plate argument

* added warning message

* fixed linting error and test failure case from too many cg iters

* eli's contextlib seeding strategy

* removed seedmessenger from test

* randomness should be shared across calls

* switched back to different

* Use Hessian formulation of Fisher information in `make_empirical_fisher_vp` (#430)

* hessian vector product formulation for fisher

* ignoring small type error

* fixed linting error

* Add new `SimpleModel` and `SimpleGuide` (#440)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* seeded NMC implementation

* linting

* removed missing import

* changed to eli's seedmessenger suggestion

* added failing edge case

* explicitly add max plate argument

* added warning message

* fixed linting error and test failure case from too many cg iters

* eli's contextlib seeding strategy

* removed seedmessenger from test

* randomness should be shared across calls

* uncomitted change before branch switch

* switched back to different

* added revised simple model and guide

* added multiple link functions in test

* linting

* Batching in `linearize` and `influence` (#465)

* batching in linearize and influence

* addressing eli's review

* added optimization for pointwise false case

* fixing lint error

* batched cg (#466)

* One step correction implemented (#467)

* one step correction

* increased tolerance

* fixing lint issue

* Replace some `torch.vmap` usage with a hand-vectorized `BatchedNMCLogPredictiveLikelihood` (#473)

* sketch batched nmc lpd

* nits

* fix type

* format

* comment

* comment

* comment

* typo

* typo

* add condition to help guarantee idempotence

* simplify edge case

* simplify plate_name

* simplify batchedobservation logic

* factorize

* simplify batched

* reorder

* comment

* remove plate_names

* types

* formatting and type

* move unbind to utils

* remove max_plate_nesting arg from get_traces

* comment

* nit

* move get_importance_traces to utils

* fix types

* generic obs type

* lint

* format

* handle observe in batchedobservations

* event dim

* move batching handlers to utils

* replace 2/3 vmaps, tests pass

* remove dead code

* format

* name args

* lint

* shuffle code

* try an extra optimization in batchedlatents

* add another optimization

* undo changes to test

* remove inplace adds

* add performance test showing speedup

* document internal helpers

* batch latents test

* move batch handlers to predictive

* add bind_leftmost_dim, document PredictiveFunctional and PredictiveModel

* use bind_leftmost_dim in log prob

* Added documentation for `chirho.robust` (#470)

* documentation

* documentation clean up w/ eli

* fix lint issue

* Make functional argument to influence_fn required (#487)

* Make functional argument required

* estimator

* docstring

* Remove guide argument from `influence_fn` and `linearize` (#489)

* Make functional argument required

* estimator

* docstring

* Remove guide, make tests pass

* rename internals.predictive to internals.nmc

* expose handlers.predictive

* expose handlers.predictive

* docstrings

* fix doc build

* fix equation

* docstring import

---------

Co-authored-by: Sam Witty <[email protected]>

* Make influence_fn a higher-order Functional (#492)

* make influence a functional

* fix test

* multiple arguments

* doc

* docstring

* docstring

* Add full corrected one step estimator (#476)

* added scaffolding to one step estimator

* kept signature the same as one_step_correction

* lint

* refactored test to include multiple estimators

* typo

* revise error

* added dict handling

* remove assert

* more informative error message

* replace dispatch with pytree flatten and unflatten

* revert arg for influence_function_estimator

* docs and lint

* lingering influence_fn

* fixed missing return

* rename

* lint

* add *model to appease the linter

---------

Co-authored-by: Raj Agrawal <[email protected]>
Co-authored-by: Eli <[email protected]>
Co-authored-by: Raj Agrawal <[email protected]>
Co-authored-by: eb8680 <[email protected]>
agrawalraj added a commit that referenced this pull request Jan 19, 2024
* added robust folder

* uncommited scratch work for log prob

* untested variational log prob

* uncomitted changes

* uncomitted changes

* pair coding w/ eli

* added tests w/ Eli

* eif

* linting

* moving test autograd to internals and deleted old utils file

* sketch influence implementation

* fix more args

* ops file

* file

* format

* lint

* clean up influence and tests

* make tests more generic

* guess max plate nesting

* linearize

* rename file

* tensor flatten

* predictive eif

* jvp type

* reorganize files

* shrink test case

* move guess_max_plate_nesting

* move cg solver to linearze

* type alias

* test_ops

* basic cg tests

* remove failing test case

* format

* move paramdict up

* remove obsolete test files

* add empty handlers

* add chirho.robust to docs

* fix memory leak in tests

* make typing compatible with python 3.8

* typing_extensions

* add branch to ci

* predictive

* remove imprecise annotation

* Added more tests for `linearize` and `make_empirical_fisher_vp` (#405)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* removed missing import

* fixed failing test with seeding

* addressing Eli's comments

* Add upper bound on number of CG steps (#404)

* upper bound on cg_iters

* address comment

* fixed test for non-symmetric matrix (#437)

* Make `NMCLogPredictiveLikelihood` seeded (#408)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* seeded NMC implementation

* linting

* removed missing import

* changed to eli's seedmessenger suggestion

* added failing edge case

* explicitly add max plate argument

* added warning message

* fixed linting error and test failure case from too many cg iters

* eli's contextlib seeding strategy

* removed seedmessenger from test

* randomness should be shared across calls

* switched back to different

* Use Hessian formulation of Fisher information in `make_empirical_fisher_vp` (#430)

* hessian vector product formulation for fisher

* ignoring small type error

* fixed linting error

* Add new `SimpleModel` and `SimpleGuide` (#440)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* seeded NMC implementation

* linting

* removed missing import

* changed to eli's seedmessenger suggestion

* added failing edge case

* explicitly add max plate argument

* added warning message

* fixed linting error and test failure case from too many cg iters

* eli's contextlib seeding strategy

* removed seedmessenger from test

* randomness should be shared across calls

* uncomitted change before branch switch

* switched back to different

* added revised simple model and guide

* added multiple link functions in test

* linting

* Batching in `linearize` and `influence` (#465)

* batching in linearize and influence

* addressing eli's review

* added optimization for pointwise false case

* fixing lint error

* batched cg (#466)

* One step correction implemented (#467)

* one step correction

* increased tolerance

* fixing lint issue

* Replace some `torch.vmap` usage with a hand-vectorized `BatchedNMCLogPredictiveLikelihood` (#473)

* sketch batched nmc lpd

* nits

* fix type

* format

* comment

* comment

* comment

* typo

* typo

* add condition to help guarantee idempotence

* simplify edge case

* simplify plate_name

* simplify batchedobservation logic

* factorize

* simplify batched

* reorder

* comment

* remove plate_names

* types

* formatting and type

* move unbind to utils

* remove max_plate_nesting arg from get_traces

* comment

* nit

* move get_importance_traces to utils

* fix types

* generic obs type

* lint

* format

* handle observe in batchedobservations

* event dim

* move batching handlers to utils

* replace 2/3 vmaps, tests pass

* remove dead code

* format

* name args

* lint

* shuffle code

* try an extra optimization in batchedlatents

* add another optimization

* undo changes to test

* remove inplace adds

* add performance test showing speedup

* document internal helpers

* batch latents test

* move batch handlers to predictive

* add bind_leftmost_dim, document PredictiveFunctional and PredictiveModel

* use bind_leftmost_dim in log prob

* Added documentation for `chirho.robust` (#470)

* documentation

* documentation clean up w/ eli

* fix lint issue

* Make functional argument to influence_fn required (#487)

* Make functional argument required

* estimator

* docstring

* Remove guide argument from `influence_fn` and `linearize` (#489)

* Make functional argument required

* estimator

* docstring

* Remove guide, make tests pass

* rename internals.predictive to internals.nmc

* expose handlers.predictive

* expose handlers.predictive

* docstrings

* fix doc build

* fix equation

* docstring import

---------

Co-authored-by: Sam Witty <[email protected]>

* Make influence_fn a higher-order Functional (#492)

* make influence a functional

* fix test

* multiple arguments

* doc

* docstring

* docstring

* Add full corrected one step estimator (#476)

* added scaffolding to one step estimator

* kept signature the same as one_step_correction

* lint

* refactored test to include multiple estimators

* typo

* revise error

* added dict handling

* remove assert

* more informative error message

* replace dispatch with pytree flatten and unflatten

* revert arg for influence_function_estimator

* docs and lint

* lingering influence_fn

* fixed missing return

* rename

* lint

* add *model to appease the linter

* add abstractions and simple temp scratch to test with squared unit normal functional with perturbation.

* removes old scratch notebook

* gets squared density running under abstraction that couples functionals and models

* gets quad and mc approximations to match, vectorization hacky.

* adds plotting and comparative to analytic.

* adds scratch experiment comparing squared density analytic vs fd approx across various epsilon lambdas

* fixes dataset splitting, breaks analytic eif

* unfixes an incorrect fix, working now.

* refactors finite difference machinery to fit experimental specs.

* switches to existing rng seed context manager.

* reverts back to what turns out to be a slightly different seeding context.

---------

Co-authored-by: Raj Agrawal <[email protected]>
Co-authored-by: Eli <[email protected]>
Co-authored-by: Sam Witty <[email protected]>
Co-authored-by: Raj Agrawal <[email protected]>
Co-authored-by: eb8680 <[email protected]>
agrawalraj added a commit that referenced this pull request Jan 24, 2024
* added robust folder

* uncommited scratch work for log prob

* untested variational log prob

* uncomitted changes

* uncomitted changes

* pair coding w/ eli

* added tests w/ Eli

* eif

* linting

* moving test autograd to internals and deleted old utils file

* sketch influence implementation

* fix more args

* ops file

* file

* format

* lint

* clean up influence and tests

* make tests more generic

* guess max plate nesting

* linearize

* rename file

* tensor flatten

* predictive eif

* jvp type

* reorganize files

* shrink test case

* move guess_max_plate_nesting

* move cg solver to linearze

* type alias

* test_ops

* basic cg tests

* remove failing test case

* format

* move paramdict up

* remove obsolete test files

* add empty handlers

* add chirho.robust to docs

* fix memory leak in tests

* make typing compatible with python 3.8

* typing_extensions

* add branch to ci

* predictive

* remove imprecise annotation

* Added more tests for `linearize` and `make_empirical_fisher_vp` (#405)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* removed missing import

* fixed failing test with seeding

* addressing Eli's comments

* Add upper bound on number of CG steps (#404)

* upper bound on cg_iters

* address comment

* fixed test for non-symmetric matrix (#437)

* Make `NMCLogPredictiveLikelihood` seeded (#408)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* seeded NMC implementation

* linting

* removed missing import

* changed to eli's seedmessenger suggestion

* added failing edge case

* explicitly add max plate argument

* added warning message

* fixed linting error and test failure case from too many cg iters

* eli's contextlib seeding strategy

* removed seedmessenger from test

* randomness should be shared across calls

* switched back to different

* Use Hessian formulation of Fisher information in `make_empirical_fisher_vp` (#430)

* hessian vector product formulation for fisher

* ignoring small type error

* fixed linting error

* Add new `SimpleModel` and `SimpleGuide` (#440)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* seeded NMC implementation

* linting

* removed missing import

* changed to eli's seedmessenger suggestion

* added failing edge case

* explicitly add max plate argument

* added warning message

* fixed linting error and test failure case from too many cg iters

* eli's contextlib seeding strategy

* removed seedmessenger from test

* randomness should be shared across calls

* uncomitted change before branch switch

* switched back to different

* added revised simple model and guide

* added multiple link functions in test

* linting

* Batching in `linearize` and `influence` (#465)

* batching in linearize and influence

* addressing eli's review

* added optimization for pointwise false case

* fixing lint error

* batched cg (#466)

* One step correction implemented (#467)

* one step correction

* increased tolerance

* fixing lint issue

* Replace some `torch.vmap` usage with a hand-vectorized `BatchedNMCLogPredictiveLikelihood` (#473)

* sketch batched nmc lpd

* nits

* fix type

* format

* comment

* comment

* comment

* typo

* typo

* add condition to help guarantee idempotence

* simplify edge case

* simplify plate_name

* simplify batchedobservation logic

* factorize

* simplify batched

* reorder

* comment

* remove plate_names

* types

* formatting and type

* move unbind to utils

* remove max_plate_nesting arg from get_traces

* comment

* nit

* move get_importance_traces to utils

* fix types

* generic obs type

* lint

* format

* handle observe in batchedobservations

* event dim

* move batching handlers to utils

* replace 2/3 vmaps, tests pass

* remove dead code

* format

* name args

* lint

* shuffle code

* try an extra optimization in batchedlatents

* add another optimization

* undo changes to test

* remove inplace adds

* add performance test showing speedup

* document internal helpers

* batch latents test

* move batch handlers to predictive

* add bind_leftmost_dim, document PredictiveFunctional and PredictiveModel

* use bind_leftmost_dim in log prob

* Added documentation for `chirho.robust` (#470)

* documentation

* documentation clean up w/ eli

* fix lint issue

* Make functional argument to influence_fn required (#487)

* Make functional argument required

* estimator

* docstring

* Remove guide argument from `influence_fn` and `linearize` (#489)

* Make functional argument required

* estimator

* docstring

* Remove guide, make tests pass

* rename internals.predictive to internals.nmc

* expose handlers.predictive

* expose handlers.predictive

* docstrings

* fix doc build

* fix equation

* docstring import

---------

Co-authored-by: Sam Witty <[email protected]>

* Make influence_fn a higher-order Functional (#492)

* make influence a functional

* fix test

* multiple arguments

* doc

* docstring

* docstring

* Add full corrected one step estimator (#476)

* added scaffolding to one step estimator

* kept signature the same as one_step_correction

* lint

* refactored test to include multiple estimators

* typo

* revise error

* added dict handling

* remove assert

* more informative error message

* replace dispatch with pytree flatten and unflatten

* revert arg for influence_function_estimator

* docs and lint

* lingering influence_fn

* fixed missing return

* rename

* lint

* add *model to appease the linter

* add abstractions and simple temp scratch to test with squared unit normal functional with perturbation.

* removes old scratch notebook

* gets squared density running under abstraction that couples functionals and models

* gets quad and mc approximations to match, vectorization hacky.

* adds plotting and comparative to analytic.

* adds scratch experiment comparing squared density analytic vs fd approx across various epsilon lambdas

* fixes dataset splitting, breaks analytic eif

* unfixes an incorrect fix, working now.

* refactors finite difference machinery to fit experimental specs.

* switches to existing rng seed context manager.

* reverts back to what turns out to be a slightly different seeding context.

* gets fd integrated into experiment exec and running.

* adds perturbable normal model to statics listing

* switches back to mean not mu

* lines up mean mu loc naming correctly.

---------

Co-authored-by: Raj Agrawal <[email protected]>
Co-authored-by: Eli <[email protected]>
Co-authored-by: Sam Witty <[email protected]>
Co-authored-by: Raj Agrawal <[email protected]>
Co-authored-by: eb8680 <[email protected]>
agrawalraj added a commit that referenced this pull request Jan 25, 2024
* added robust folder

* uncommited scratch work for log prob

* untested variational log prob

* uncomitted changes

* uncomitted changes

* pair coding w/ eli

* added tests w/ Eli

* eif

* linting

* moving test autograd to internals and deleted old utils file

* sketch influence implementation

* fix more args

* ops file

* file

* format

* lint

* clean up influence and tests

* make tests more generic

* guess max plate nesting

* linearize

* rename file

* tensor flatten

* predictive eif

* jvp type

* reorganize files

* shrink test case

* move guess_max_plate_nesting

* move cg solver to linearze

* type alias

* test_ops

* basic cg tests

* remove failing test case

* format

* move paramdict up

* remove obsolete test files

* add empty handlers

* add chirho.robust to docs

* fix memory leak in tests

* make typing compatible with python 3.8

* typing_extensions

* add branch to ci

* predictive

* remove imprecise annotation

* Added more tests for `linearize` and `make_empirical_fisher_vp` (#405)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* removed missing import

* fixed failing test with seeding

* addressing Eli's comments

* Add upper bound on number of CG steps (#404)

* upper bound on cg_iters

* address comment

* fixed test for non-symmetric matrix (#437)

* Make `NMCLogPredictiveLikelihood` seeded (#408)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* seeded NMC implementation

* linting

* removed missing import

* changed to eli's seedmessenger suggestion

* added failing edge case

* explicitly add max plate argument

* added warning message

* fixed linting error and test failure case from too many cg iters

* eli's contextlib seeding strategy

* removed seedmessenger from test

* randomness should be shared across calls

* switched back to different

* Use Hessian formulation of Fisher information in `make_empirical_fisher_vp` (#430)

* hessian vector product formulation for fisher

* ignoring small type error

* fixed linting error

* Add new `SimpleModel` and `SimpleGuide` (#440)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* seeded NMC implementation

* linting

* removed missing import

* changed to eli's seedmessenger suggestion

* added failing edge case

* explicitly add max plate argument

* added warning message

* fixed linting error and test failure case from too many cg iters

* eli's contextlib seeding strategy

* removed seedmessenger from test

* randomness should be shared across calls

* uncomitted change before branch switch

* switched back to different

* added revised simple model and guide

* added multiple link functions in test

* linting

* Batching in `linearize` and `influence` (#465)

* batching in linearize and influence

* addressing eli's review

* added optimization for pointwise false case

* fixing lint error

* batched cg (#466)

* One step correction implemented (#467)

* one step correction

* increased tolerance

* fixing lint issue

* Replace some `torch.vmap` usage with a hand-vectorized `BatchedNMCLogPredictiveLikelihood` (#473)

* sketch batched nmc lpd

* nits

* fix type

* format

* comment

* comment

* comment

* typo

* typo

* add condition to help guarantee idempotence

* simplify edge case

* simplify plate_name

* simplify batchedobservation logic

* factorize

* simplify batched

* reorder

* comment

* remove plate_names

* types

* formatting and type

* move unbind to utils

* remove max_plate_nesting arg from get_traces

* comment

* nit

* move get_importance_traces to utils

* fix types

* generic obs type

* lint

* format

* handle observe in batchedobservations

* event dim

* move batching handlers to utils

* replace 2/3 vmaps, tests pass

* remove dead code

* format

* name args

* lint

* shuffle code

* try an extra optimization in batchedlatents

* add another optimization

* undo changes to test

* remove inplace adds

* add performance test showing speedup

* document internal helpers

* batch latents test

* move batch handlers to predictive

* add bind_leftmost_dim, document PredictiveFunctional and PredictiveModel

* use bind_leftmost_dim in log prob

* Added documentation for `chirho.robust` (#470)

* documentation

* documentation clean up w/ eli

* fix lint issue

* progress on tmle

* placeholder test

* more progress on TMLE

* more progress, still need to refactor

* progress on variational tmle

* Make functional argument to influence_fn required (#487)

* Make functional argument required

* estimator

* docstring

* Remove guide argument from `influence_fn` and `linearize` (#489)

* Make functional argument required

* estimator

* docstring

* Remove guide, make tests pass

* rename internals.predictive to internals.nmc

* expose handlers.predictive

* expose handlers.predictive

* docstrings

* fix doc build

* fix equation

* docstring import

---------

Co-authored-by: Sam Witty <[email protected]>

* more progress on tmle

* really resolved merge conflicts

* more progress, still a bit stuck on functional tensors

* Make influence_fn a higher-order Functional (#492)

* make influence a functional

* fix test

* multiple arguments

* doc

* docstring

* docstring

* update tmle signature and remove unused imports

* make tmle signature consistent with one-step

* lint

* progress

* pair program still issues

* debugging still

* Add full corrected one step estimator (#476)

* added scaffolding to one step estimator

* kept signature the same as one_step_correction

* lint

* refactored test to include multiple estimators

* typo

* revise error

* added dict handling

* remove assert

* more informative error message

* replace dispatch with pytree flatten and unflatten

* revert arg for influence_function_estimator

* docs and lint

* lingering influence_fn

* fixed missing return

* rename

* lint

* add *model to appease the linter

* more attempts

* added scipy optimize :(

* more progress

* more progress

* working end-to-end tmle

* remove comment

* revert changes

* update tests and defaults

* lint

* playing with tmle performance

* more tweaks

* pulled out influence computation and changed loss

* finally got tmle working

* revert test

* added placeholder for passing in influence_fn_estimator

* analytic influence for example

* lint

* fix estimator

* fix tests

* lint

* notebook

* bump notebook

* use torchopt

* add torchopt

* rerun tmle notebook with effect = 1

* lint

---------

Co-authored-by: Raj Agrawal <[email protected]>
Co-authored-by: Eli <[email protected]>
Co-authored-by: Raj Agrawal <[email protected]>
Co-authored-by: eb8680 <[email protected]>
SamWitty added a commit that referenced this pull request Jul 18, 2024
* added robust folder

* uncommited scratch work for log prob

* untested variational log prob

* uncomitted changes

* uncomitted changes

* pair coding w/ eli

* added tests w/ Eli

* eif

* linting

* moving test autograd to internals and deleted old utils file

* sketch influence implementation

* fix more args

* ops file

* file

* format

* lint

* clean up influence and tests

* make tests more generic

* guess max plate nesting

* linearize

* rename file

* tensor flatten

* predictive eif

* jvp type

* reorganize files

* shrink test case

* move guess_max_plate_nesting

* move cg solver to linearze

* type alias

* test_ops

* basic cg tests

* remove failing test case

* format

* move paramdict up

* remove obsolete test files

* add empty handlers

* add chirho.robust to docs

* fix memory leak in tests

* make typing compatible with python 3.8

* typing_extensions

* add branch to ci

* predictive

* remove imprecise annotation

* Added more tests for `linearize` and `make_empirical_fisher_vp` (#405)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* removed missing import

* fixed failing test with seeding

* addressing Eli's comments

* Add upper bound on number of CG steps (#404)

* upper bound on cg_iters

* address comment

* fixed test for non-symmetric matrix (#437)

* Make `NMCLogPredictiveLikelihood` seeded (#408)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* seeded NMC implementation

* linting

* removed missing import

* changed to eli's seedmessenger suggestion

* added failing edge case

* explicitly add max plate argument

* added warning message

* fixed linting error and test failure case from too many cg iters

* eli's contextlib seeding strategy

* removed seedmessenger from test

* randomness should be shared across calls

* switched back to different

* Use Hessian formulation of Fisher information in `make_empirical_fisher_vp` (#430)

* hessian vector product formulation for fisher

* ignoring small type error

* fixed linting error

* Add new `SimpleModel` and `SimpleGuide` (#440)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* seeded NMC implementation

* linting

* removed missing import

* changed to eli's seedmessenger suggestion

* added failing edge case

* explicitly add max plate argument

* added warning message

* fixed linting error and test failure case from too many cg iters

* eli's contextlib seeding strategy

* removed seedmessenger from test

* randomness should be shared across calls

* uncomitted change before branch switch

* switched back to different

* added revised simple model and guide

* added multiple link functions in test

* linting

* Batching in `linearize` and `influence` (#465)

* batching in linearize and influence

* addressing eli's review

* added optimization for pointwise false case

* fixing lint error

* batched cg (#466)

* One step correction implemented (#467)

* one step correction

* increased tolerance

* fixing lint issue

* Replace some `torch.vmap` usage with a hand-vectorized `BatchedNMCLogPredictiveLikelihood` (#473)

* sketch batched nmc lpd

* nits

* fix type

* format

* comment

* comment

* comment

* typo

* typo

* add condition to help guarantee idempotence

* simplify edge case

* simplify plate_name

* simplify batchedobservation logic

* factorize

* simplify batched

* reorder

* comment

* remove plate_names

* types

* formatting and type

* move unbind to utils

* remove max_plate_nesting arg from get_traces

* comment

* nit

* move get_importance_traces to utils

* fix types

* generic obs type

* lint

* format

* handle observe in batchedobservations

* event dim

* move batching handlers to utils

* replace 2/3 vmaps, tests pass

* remove dead code

* format

* name args

* lint

* shuffle code

* try an extra optimization in batchedlatents

* add another optimization

* undo changes to test

* remove inplace adds

* add performance test showing speedup

* document internal helpers

* batch latents test

* move batch handlers to predictive

* add bind_leftmost_dim, document PredictiveFunctional and PredictiveModel

* use bind_leftmost_dim in log prob

* Added documentation for `chirho.robust` (#470)

* documentation

* documentation clean up w/ eli

* fix lint issue

* old dr notebook that got deleted from wrong merge

* added missing fig

* redid notebook with new interface

* Make functional argument to influence_fn required (#487)

* Make functional argument required

* estimator

* docstring

* Remove guide argument from `influence_fn` and `linearize` (#489)

* Make functional argument required

* estimator

* docstring

* Remove guide, make tests pass

* rename internals.predictive to internals.nmc

* expose handlers.predictive

* expose handlers.predictive

* docstrings

* fix doc build

* fix equation

* docstring import

---------

Co-authored-by: Sam Witty <[email protected]>

* updated labels

* updated w/ new interface but only 1 data sim

* Make influence_fn a higher-order Functional (#492)

* make influence a functional

* fix test

* multiple arguments

* doc

* docstring

* docstring

* uncommitted changes

* Add full corrected one step estimator (#476)

* added scaffolding to one step estimator

* kept signature the same as one_step_correction

* lint

* refactored test to include multiple estimators

* typo

* revise error

* added dict handling

* remove assert

* more informative error message

* replace dispatch with pytree flatten and unflatten

* revert arg for influence_function_estimator

* docs and lint

* lingering influence_fn

* fixed missing return

* rename

* lint

* add *model to appease the linter

* kernel speedup

* before switching to krr formulation

* uncommitted changes

* updated w/ new interface; removed GP section for now

* runs but not matching

* still not working, going to make major changes

* remove debug script

* remove file

* remove file

* add

* update interfaces

* finished running

* outline

* remove outline for now

* simplify notebook

* merge

---------

Co-authored-by: Eli <[email protected]>
Co-authored-by: Sam Witty <[email protected]>
Co-authored-by: eb8680 <[email protected]>
Co-authored-by: Eli <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request module:robust refactor status:awaiting review Awaiting response from reviewer
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants